# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import TypeVar

    from numpy.typing import NDArray

    DT = TypeVar("DT", bound=np.generic)

__all__ = ["jackknife_resampling", "jackknife_stats"]
__doctest_requires__ = {"jackknife_stats": ["scipy"]}


def jackknife_resampling(data: NDArray[DT]) -> NDArray[DT]:
    """Performs jackknife resampling on numpy arrays.

    Jackknife resampling is a technique to generate ``n`` deterministic
    samples of size ``n-1`` from a measured sample of size ``n``. The i-th
    sample (1 <= i <= n) is generated by removing the i-th measurement from
    the original sample. Like bootstrap resampling, this statistical
    technique finds applications in estimating variance, bias, and
    confidence intervals.

    Parameters
    ----------
    data : ndarray
        Original sample (1-D array) from which the jackknife resamples will be
        generated.

    Returns
    -------
    resamples : ndarray
        The i-th row is the i-th jackknife sample, i.e., the original sample
        with the i-th measurement deleted.

    References
    ----------
    .. [1] McIntosh, Avery. "The Jackknife Estimation Method".
        <https://arxiv.org/abs/1606.00497>

    .. [2] Efron, Bradley. "The Jackknife, the Bootstrap, and other
        Resampling Plans". Technical Report No. 63, Division of Biostatistics,
        Stanford University, December, 1980.

    .. [3] Jackknife resampling <https://en.wikipedia.org/wiki/Jackknife_resampling>
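
    Examples
    --------
    For example, a four-element sample yields four resamples of size three,
    each missing one of the original measurements:

    >>> import numpy as np
    >>> from astropy.stats import jackknife_resampling
    >>> jackknife_resampling(np.array([1, 2, 3, 4]))
    array([[2., 3., 4.],
           [1., 3., 4.],
           [1., 2., 4.],
           [1., 2., 3.]])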
    """
    n = data.shape[0]
    if n <= 0:
        raise ValueError("data must contain at least one measurement.")

    resamples = np.empty([n, n - 1])

    for i in range(n):
        resamples[i] = np.delete(data, i)

    return resamples


def jackknife_stats(
    data: NDArray,
    statistic: Callable,
    confidence_level: float = 0.95,
) -> tuple[float | NDArray, float | NDArray, float | NDArray, NDArray]:
    """Performs jackknife estimation on the basis of jackknife resamples.

    This function requires `SciPy <https://www.scipy.org/>`_ to be installed.

    Parameters
    ----------
    data : ndarray
        Original sample (1-D array).
    statistic : function
        Any function (or vector of functions) of the measured data, e.g.,
        the sample mean or sample variance. The jackknife estimate of this
        statistic will be returned.
    confidence_level : float, optional
        Confidence level for the confidence interval of the jackknife
        estimate. Must be a real-valued number in (0, 1). Default value is
        0.95.

    Returns
    -------
    estimate : float or `~numpy.ndarray`
        The i-th element is the bias-corrected "jackknifed" estimate.

    bias : float or `~numpy.ndarray`
        The i-th element is the jackknife bias.

    std_err : float or `~numpy.ndarray`
        The i-th element is the jackknife standard error.

    conf_interval : ndarray
        If ``statistic`` is single-valued, the first and second elements are
        the lower and upper bounds, respectively. If ``statistic`` is
        vector-valued, each column corresponds to the confidence interval for
        each component of ``statistic``. The first and second rows contain the
        lower and upper bounds, respectively.

    Examples
    --------
    1. Obtain Jackknife resamples:

    >>> import numpy as np
    >>> from astropy.stats import jackknife_resampling
    >>> from astropy.stats import jackknife_stats
    >>> data = np.array([1,2,3,4,5,6,7,8,9,0])
    >>> resamples = jackknife_resampling(data)
    >>> resamples
    array([[2., 3., 4., 5., 6., 7., 8., 9., 0.],
           [1., 3., 4., 5., 6., 7., 8., 9., 0.],
           [1., 2., 4., 5., 6., 7., 8., 9., 0.],
           [1., 2., 3., 5., 6., 7., 8., 9., 0.],
           [1., 2., 3., 4., 6., 7., 8., 9., 0.],
           [1., 2., 3., 4., 5., 7., 8., 9., 0.],
           [1., 2., 3., 4., 5., 6., 8., 9., 0.],
           [1., 2., 3., 4., 5., 6., 7., 9., 0.],
           [1., 2., 3., 4., 5., 6., 7., 8., 0.],
           [1., 2., 3., 4., 5., 6., 7., 8., 9.]])
    >>> resamples.shape
    (10, 9)

    2. Obtain Jackknife estimate for the mean, its bias, its standard error,
    and its 95% confidence interval:

    >>> test_statistic = np.mean
    >>> estimate, bias, stderr, conf_interval = jackknife_stats(
    ...     data, test_statistic, 0.95)
    >>> estimate
    np.float64(4.5)
    >>> bias
    np.float64(0.0)
    >>> stderr  # doctest: +FLOAT_CMP
    np.float64(0.95742710775633832)
    >>> conf_interval
    array([2.62347735,  6.37652265])

    3. Example for two estimates, the sample mean and the sample variance:

    >>> test_statistic = lambda x: (np.mean(x), np.var(x))
    >>> estimate, bias, stderr, conf_interval = jackknife_stats(
    ...     data, test_statistic, 0.95)
    >>> estimate
    array([4.5       ,  9.16666667])
    >>> bias
    array([ 0.        , -0.91666667])
    >>> stderr
    array([0.95742711,  2.69124476])
    >>> conf_interval
    array([[ 2.62347735,   3.89192387],
           [ 6.37652265,  14.44140947]])

    IMPORTANT: Note that confidence intervals are given as columns.
    """
    # validate the confidence level for the jackknife confidence interval
    if not (0 < confidence_level < 1):
        raise ValueError("confidence level must be in (0, 1).")

    # make sure original data is proper
    n = data.shape[0]
    if n <= 0:
        raise ValueError("data must contain at least one measurement.")

    # Only import scipy if inputs are valid
    from scipy.special import erfinv

    resamples = jackknife_resampling(data)

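    # statistic evaluated on the original sample and, row by row, on each
    # leave-one-out resample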
    stat_data = statistic(data)
    jack_stat = np.apply_along_axis(statistic, 1, resamples)
    mean_jack_stat = np.mean(jack_stat, axis=0)

    # jackknife bias
    bias = (n - 1) * (mean_jack_stat - stat_data)

    # jackknife standard error
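    # equivalently sqrt((n - 1) / n * sum of squared deviations of the
    # leave-one-out statistics from their mean)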
    std_err = np.sqrt(
        (n - 1)
        * np.mean((jack_stat - mean_jack_stat) * (jack_stat - mean_jack_stat), axis=0)
    )

    # bias-corrected "jackknifed estimate"
    estimate = stat_data - bias

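    # two-sided normal-approximation interval: sqrt(2) * erfinv(p) is the
    # standard normal quantile Phi^{-1}((1 + p) / 2) for confidence level p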
    z_score = np.sqrt(2.0) * erfinv(confidence_level)
    conf_interval = estimate + z_score * np.array((-std_err, std_err))

    return estimate, bias, std_err, conf_interval
